from typing import Iterable, Optional
from torch.nn.modules.module import Module
from model.extractor import vanilla_extractor
from decoder.LSTM_decoder import vanilla_LSTM
from dagma.linear import DagmaLinear
from model.hsic import HSIC

from model.lyapunov_spectrum import extract_cell_states, lyapunov_solve_unknown

from sklearn.metrics import r2_score
from torch import nn

import torch
import numpy as np

class vanilla_model(nn.Module):
    """Baseline decoder: a ``vanilla_extractor`` read-in feeding a ``vanilla_LSTM``.

    Constructor arguments are forwarded positionally: ``input_dim``,
    ``low_dim`` and ``drop_out`` to the extractor; ``low_dim``, ``h_dim``,
    ``lstm_layer``, ``window_size`` and ``pos_dim`` to the decoder.
    """

    def __init__(self, input_dim, low_dim, drop_out, h_dim, lstm_layer, window_size, pos_dim):
        super().__init__()
        # Plain attribute slot; this class never assigns an aligner itself.
        self.aligner = None
        self.extractor = vanilla_extractor(input_dim, low_dim, drop_out)
        self.decoder = vanilla_LSTM(low_dim, h_dim, lstm_layer, window_size, pos_dim)

    def forward(self, src_x, src_y):
        """Extract features from ``src_x`` and decode them against ``src_y``.

        Returns:
            The two values produced by ``self.decoder``, treated here as
            (prediction, MSE loss).
        """
        features = self.extractor(src_x)
        prediction, loss = self.decoder(features, src_y)
        return prediction, loss

class VAE_Model(nn.Module):
    """Sequence VAE that decodes cursor position and reconstructs spikes.

    ``forward`` runs ``read_in`` -> LSTM ``encoder`` -> Gaussian heads
    (``fc_mu`` for the mean, ``fc_std`` for the *log-variance*). The heads are
    applied both to the encoder's layer-0 final cell state (trial latent, used
    for the position readout) and to every encoder output step (per-step
    latents, used for reconstruction through the LSTM ``decoder`` and the
    Softplus ``recon_layer`` under a Poisson NLL).
    """

    def __init__(self, input_dim, low_dim, drop_out, h_dim, lstm_layer, latent_dim, pos_dim,
                 kld_weight_rec, kld_weight_pos, rec_weight):
        super().__init__()

        # Hyper-parameters.
        self.input_dim = input_dim
        self.low_dim = low_dim
        self.drop_out = drop_out
        self.h_dim = h_dim
        self.lstm_layer = lstm_layer
        self.latent_dim = latent_dim
        self.pos_dim = pos_dim

        # Loss functions and their weights.
        self.mse_regression = nn.MSELoss()
        # log_input=False: the loss receives rates (Softplus output), not log-rates.
        self.poisson_criterion = nn.PoissonNLLLoss(log_input=False)
        self.kld_weight_rec = kld_weight_rec
        self.kld_weight_pos = kld_weight_pos
        self.rec_weight = rec_weight

        # Sub-modules are created in a fixed order: under a fixed seed this
        # fixes the random initialisation, and the attribute names fix the
        # state_dict keys.
        self.read_in = vanilla_extractor(input_dim, low_dim, drop_out)

        self.encoder = nn.LSTM(input_size=low_dim, hidden_size=h_dim,
                               num_layers=lstm_layer, batch_first=True)
        self._xavier_reinit(self.encoder)

        self.fc_mu = self._gaussian_head(h_dim, latent_dim, drop_out)
        self.fc_std = self._gaussian_head(h_dim, latent_dim, drop_out)

        readout_width = int(latent_dim)
        self.pos_read_out = nn.Sequential(
            nn.BatchNorm1d(readout_width),
            nn.Linear(readout_width, pos_dim),
            nn.Dropout(drop_out),
        )

        self.decoder = nn.LSTM(input_size=latent_dim, hidden_size=latent_dim,
                               num_layers=lstm_layer, batch_first=True)
        self._xavier_reinit(self.decoder)

        self.recon_layer = nn.Sequential(
            nn.BatchNorm1d(latent_dim),
            nn.Linear(latent_dim, h_dim),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.BatchNorm1d(h_dim),
            nn.Linear(h_dim, input_dim),
            nn.Softplus(),  # strictly positive rates for the Poisson loss
        )

    @staticmethod
    def _xavier_reinit(module, gain=0.1):
        """Xavier-uniform re-initialise every parameter of ``module`` with more than one dimension."""
        for weight in module.parameters():
            if weight.dim() > 1:
                nn.init.xavier_uniform_(weight, gain)

    @staticmethod
    def _gaussian_head(in_features, out_features, drop_out):
        """Build the BatchNorm -> Linear -> Dropout stack used for the mean / log-variance heads."""
        return nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Linear(in_features, out_features),
            nn.Dropout(drop_out),
        )

    @staticmethod
    def _kl_to_standard_normal(mu, log_var):
        """Return KL(N(mu, exp(log_var)) || N(0, 1)) averaged over all elements."""
        return torch.mean(0.5 * (-log_var + mu ** 2 + log_var.exp() - 1))

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: draw a sample of N(mu, exp(logvar)) as
        mu + sigma * eps with eps ~ N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def forward(self, src_x, src_y, train_flag=False):
        """Encode, reconstruct and decode position; return predictions and losses.

        Args:
            src_x: Input activity. The reconstruction is compared to it with the
                Poisson NLL, so presumably (batch, time, input_dim) — TODO confirm.
            src_y: Position targets compared with the readout by MSE.
            train_flag: If True, latents are sampled with ``reparameterize``;
                otherwise the means are used.

        Returns:
            (y_pred, mse_loss, total_loss), where ``total_loss`` adds the weighted
            Poisson reconstruction loss and both KL terms to ``mse_loss``.
        """
        # Encoder outputs are (batch, time, h_dim) since batch_first=True; h_n is unused.
        step_hidden, (_, cell_state) = self.encoder(self.read_in(src_x))
        # NOTE(review): cell_state[0] is the *first* LSTM layer; for
        # lstm_layer > 1 this is not the top layer — confirm intended.
        trial_hidden = cell_state[0]

        # Trial-level latent (drives the position readout).
        mu = self.fc_mu(trial_hidden)
        log_var = self.fc_std(trial_hidden)
        z = self.reparameterize(mu, log_var) if train_flag else mu

        # Per-time-step latents from the same heads (drive reconstruction).
        n_batch = step_hidden.shape[0]
        flat_steps = step_hidden.reshape(-1, self.h_dim)
        mu_rec = self.fc_mu(flat_steps).reshape(n_batch, -1, self.latent_dim)
        log_var_rec = self.fc_std(flat_steps).reshape(n_batch, -1, self.latent_dim)
        z_rec = self.reparameterize(mu_rec, log_var_rec) if train_flag else mu_rec

        decoded, _ = self.decoder(z_rec)
        rates = self.recon_layer(decoded.reshape(-1, self.latent_dim))
        rates = rates.reshape(n_batch, -1, self.input_dim)

        y_pred = self.pos_read_out(z)

        mse_loss = self.mse_regression(y_pred, src_y)
        rec_loss = self.poisson_criterion(rates, src_x)
        kld_loss_rec = self._kl_to_standard_normal(mu_rec, log_var_rec)
        kld_loss_pos = self._kl_to_standard_normal(mu, log_var)

        total_loss = (mse_loss
                      + self.rec_weight * rec_loss
                      + self.kld_weight_rec * kld_loss_rec
                      + self.kld_weight_pos * kld_loss_pos)
        return y_pred, mse_loss, total_loss



class VAE_Disentangler(nn.Module):
    """Gaussian posterior head mapping features to ``(z, mu, log_var)``.

    ``fc_mu`` and ``fc_std`` are independent BatchNorm -> Linear -> Dropout
    stacks of width ``h_dim -> latent_dim``. Despite its name, ``fc_std``
    produces the *log-variance*: it is passed as ``logvar`` to
    ``reparameterize``.
    """

    def __init__(self, h_dim, latent_dim, drop_out):
        super().__init__()

        self.h_dim = h_dim
        self.latent_dim = latent_dim
        self.drop_out = drop_out

        def head():
            return nn.Sequential(
                nn.BatchNorm1d(h_dim),
                nn.Linear(h_dim, latent_dim),
                nn.Dropout(drop_out),
            )

        # Built mean-first, so the initialisation order matches a fixed seed.
        self.fc_mu = head()
        self.fc_std = head()

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: draw a sample of N(mu, exp(logvar)) as
        mu + sigma * eps with eps ~ N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mu

    def forward(self, input, train_flag=False):
        """Return ``(z, mu, log_var)``; ``z`` is sampled only when ``train_flag`` is True, else ``z is mu``."""
        mu = self.fc_mu(input)
        log_var = self.fc_std(input)
        if train_flag:
            z = self.reparameterize(mu=mu, logvar=log_var)
        else:
            z = mu
        return z, mu, log_var

class spike_Reconstrcuter(nn.Module):
    """Map a latent code to ``h_dim`` features through a single affine head.

    ``forward`` applies only ``recon_layer`` (BatchNorm -> Linear(latent_dim,
    h_dim) -> Dropout). The LSTM ``decoder`` is still constructed and
    registered but is never called by ``forward``. Removing it would change
    the ``state_dict`` keys and the parameter-initialisation sequence.
    ``input_dim`` is stored but not used in this class.
    """

    def __init__(self, latent_dim, lstm_layer, h_dim, drop_out, input_dim):
        super().__init__()

        self.latent_dim, self.lstm_layer, self.h_dim = latent_dim, lstm_layer, h_dim
        self.drop_out, self.input_dim = drop_out, input_dim

        # Registered but unused by forward() (see class docstring).
        self.decoder = nn.LSTM(input_size=latent_dim, hidden_size=latent_dim,
                               num_layers=lstm_layer, batch_first=True)
        matrices = [p for p in self.decoder.parameters() if p.dim() > 1]
        for weight in matrices:
            nn.init.xavier_uniform_(weight, 0.1)

        self.recon_layer = nn.Sequential(
            nn.BatchNorm1d(latent_dim),
            nn.Linear(latent_dim, h_dim),
            nn.Dropout(drop_out),
        )

    def forward(self, input):
        """Project ``input`` (presumably (batch, latent_dim) — TODO confirm) to ``h_dim`` features."""
        return self.recon_layer(input)

class domain_Discriminator(nn.Module):
    """MLP classifier whose outputs are probabilities in [0, 1].

    Layout of ``classifier``: BatchNorm -> (Linear -> Dropout -> ReLU) x 2 ->
    Linear -> Sigmoid. ``bce_criterion`` is an ``nn.BCELoss`` for callers,
    matching the already-sigmoided output.
    """

    def __init__(self, input_dim, hidden_dim, drop_out, output_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.drop_out = drop_out
        self.output_dim = output_dim

        layers = [nn.BatchNorm1d(input_dim)]
        for fan_in in (input_dim, hidden_dim):
            layers += [nn.Linear(fan_in, hidden_dim), nn.Dropout(drop_out), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, output_dim), nn.Sigmoid()]
        self.classifier = nn.Sequential(*layers)

        self.bce_criterion = nn.BCELoss()

    def forward(self, input):
        """Return per-sample probabilities of shape (batch, output_dim)."""
        return self.classifier(input)
    
    

class disentangle_VAE_Model(nn.Module):
    """VAE decoder whose latent is split into behaviour and domain halves.

    Pipeline: ``read_in_src`` -> LSTM ``encoder`` -> layer-0 final cell state
    ``x_hid`` -> ``behavior_VAE`` -> latent ``z`` of width ``latent_dim``.
    The first ``latent_dim // 2`` features of ``z`` feed ``pos_read_out``.
    ``domain_Reconstructor_src`` maps the full ``z`` back to ``x_hid``.
    ``domain_classifier`` and ``behavior_classifier`` are discriminators over
    half-width latents; ``train_disentangle_VAE_Model`` trains them.
    """

    def __init__(self, input_dim, low_dim, drop_out, h_dim, lstm_layer, latent_dim, pos_dim,
                 kld_weight_rec, kld_weight_pos, rec_weight, mse_weight=1.0, domain_weight=1.0):
        super(disentangle_VAE_Model, self).__init__()

        # Hyper-parameters
        self.input_dim = input_dim
        self.low_dim = low_dim
        self.drop_out = drop_out
        self.h_dim = h_dim
        self.lstm_layer = lstm_layer
        self.latent_dim = latent_dim
        self.pos_dim = pos_dim

        # Loss weights
        self.kld_weight_rec = kld_weight_rec
        self.kld_weight_pos = kld_weight_pos
        self.rec_weight = rec_weight
        self.mse_weight = mse_weight
        self.domain_weight = domain_weight

        # Loss functions
        self.mse_regression = nn.MSELoss()
        self.poisson_criterion = nn.PoissonNLLLoss(log_input=False)

        # Width of each latent half (behaviour-relevant / domain-specific).
        half_latent = int(self.latent_dim) // 2

        # NOTE: sub-modules are created in a fixed order. Under a fixed seed,
        # reordering them would change the random initialisation.

        # Low-D read-in, re-initialised with Xavier-uniform (gain 0.1).
        self.read_in_src = vanilla_extractor(self.input_dim, self.low_dim, drop_out)
        self._xavier_reinit(self.read_in_src)

        # LSTM encoder (batch_first: inputs/outputs are (batch, time, features)).
        self.encoder = nn.LSTM(input_size=self.low_dim, hidden_size=self.h_dim,
                               num_layers=self.lstm_layer, batch_first=True)
        self._xavier_reinit(self.encoder)

        # Variational bottleneck: x_hid (h_dim) -> z (latent_dim).
        self.behavior_VAE = VAE_Disentangler(h_dim=self.h_dim, latent_dim=self.latent_dim,
                                             drop_out=self.drop_out)

        # Source-vs-target discriminators over half-width latents.
        self.domain_classifier = domain_Discriminator(input_dim=half_latent,
                                                      hidden_dim=half_latent,
                                                      drop_out=self.drop_out,
                                                      output_dim=1)
        self.behavior_classifier = domain_Discriminator(input_dim=half_latent,
                                                        hidden_dim=half_latent,
                                                        drop_out=self.drop_out,
                                                        output_dim=1)

        # Cursor position readout from the behaviour half of z.
        self.pos_read_out = nn.Sequential(nn.BatchNorm1d(half_latent),
                                          nn.Linear(half_latent, self.pos_dim),
                                          nn.Dropout(self.drop_out))

        # Reconstructs x_hid (h_dim features) from the full latent z.
        self.domain_Reconstructor_src = spike_Reconstrcuter(latent_dim=self.latent_dim,
                                                            lstm_layer=self.lstm_layer,
                                                            h_dim=self.h_dim,
                                                            drop_out=self.drop_out,
                                                            input_dim=self.input_dim)

        # Mutual-information estimator; stored for callers, not invoked in this class.
        self.mi_Estimator = HSIC

    @staticmethod
    def _xavier_reinit(module, gain=0.1):
        """Xavier-uniform re-initialise every parameter of ``module`` with more than one dimension."""
        for param in module.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param, gain)

    def get_z_latent(self, src_x, train_flag=False, dis_flag=True):
        """Encode ``src_x`` into the latent code.

        Args:
            src_x: Input activity fed to ``read_in_src``.
            train_flag: If True, ``z`` is sampled by the VAE; otherwise ``z`` is the mean.
            dis_flag: If False, ``z`` is replaced by the raw encoder summary ``x_hid``.
                The VAE is still run so that the KL term is available.

        Returns:
            (z, kld_loss_pos, x_hid). ``kld_loss_pos`` is
            KL(N(mu, exp(log_var)) || N(0, 1)) averaged over all elements.
        """
        x_low = self.read_in_src(src_x)
        _, (_, cell_state) = self.encoder(x_low)
        # NOTE(review): index 0 selects the *first* LSTM layer's final cell
        # state; for lstm_layer > 1 this is not the top layer — confirm intended.
        x_hid = cell_state[0]

        z, mu, log_var = self.behavior_VAE(x_hid, train_flag)
        kld_loss_pos = torch.mean(0.5 * (- log_var + mu ** 2 + log_var.exp() - 1))

        if not dis_flag:
            z = x_hid
        return z, kld_loss_pos, x_hid

    def forward(self, src_x, src_y, src_flag=True, domain_flag=False, train_flag=False, dis_flag=True):
        """Predict position and compute losses for one batch.

        ``src_flag`` and ``domain_flag`` are unused; they are kept for call compatibility.

        Returns:
            (y_pred, mse_loss, total_loss), where
            ``total_loss = mse_loss + rec_weight * rec_loss + kld_weight_pos * kld_loss_pos``.
            ``rec_loss`` is 0 when ``dis_flag`` is False.
        """
        z, kld_loss_pos, x_hid = self.get_z_latent(src_x, train_flag, dis_flag)

        if dis_flag:
            # Reconstruct x_hid from the full latent; read position from the behaviour half.
            re_sp = self.domain_Reconstructor_src(z)
            pos_latent = z[:, :int(self.latent_dim) // 2]
        else:
            re_sp = None
            pos_latent = z
        y_pred = self.pos_read_out(pos_latent)

        mse_loss = self.mse_regression(y_pred, src_y)
        rec_loss = self.mse_regression(re_sp, x_hid) if dis_flag else 0

        total_loss = mse_loss + self.rec_weight*rec_loss + self.kld_weight_pos*kld_loss_pos
        return y_pred, mse_loss, total_loss

    def get_model_pos_latent(self, src_x, domain_flag, train_flag):
        """Return the latent code ``z`` produced by ``get_z_latent`` for ``src_x``.

        ``domain_flag`` is accepted for backward compatibility and ignored,
        because the model has a single VAE. ``train_flag`` controls sampling
        as in ``get_z_latent``.
        """
        # get_z_latent returns three values and takes train_flag as its second argument.
        z, _, _ = self.get_z_latent(src_x, train_flag=train_flag)
        return z

def train_disentangle_VAE_Model(model, device, src_x, src_y, tgt_x, grl_weight, hsic_weight, optimizer_VAE, optimizer_ds, optimizer_br, dis_flag=True):
    """Run one adversarial training step of a ``disentangle_VAE_Model``.

    Phase 1 trains ``model.domain_classifier`` on the domain-specific half of
    the latent and ``model.behavior_classifier`` on the behaviour half. Both
    learn to label source samples 1 and target samples 0, using detached
    latents.

    Phase 2 freezes both classifiers and back-propagates the following loss
    into the rest of the model:

    - source position MSE (if ``model.mse_weight > 0``);
    - latent-reconstruction and KL terms for both domains (if ``dis_flag``);
    - minus ``grl_weight`` times (BCE(target, 0) - BCE(source, 1)) of the
      behaviour classifier. This rewards behaviour latents that the classifier
      labels as source in both domains.

    Both classifiers are left with ``requires_grad=False`` on return. They are
    re-enabled at the start of the next call.

    Args:
        model: A ``disentangle_VAE_Model``.
        device: Device on which the domain labels are created.
        src_x, src_y: Source-domain inputs and position targets.
        tgt_x: Target-domain inputs.
        grl_weight: Weight of the adversarial behaviour-alignment term (off if <= 0).
        hsic_weight: Currently unused (the HSIC penalty is disabled); kept for call compatibility.
        optimizer_VAE: Stepped after the phase-2 backward pass.
        optimizer_ds: Presumably holds ``model.domain_classifier`` parameters.
        optimizer_br: Presumably holds ``model.behavior_classifier`` parameters.
        dis_flag: If False, the full latent is used as both halves and the
            reconstruction/KL terms are dropped.

    Returns:
        None. Losses and the source-domain R^2 are printed.
    """
    train_flag = True

    z_src, kld_loss_pos_src, x_hid_src = model.get_z_latent(src_x, train_flag, dis_flag)
    z_tgt, kld_loss_pos_tgt, x_hid_tgt = model.get_z_latent(tgt_x, train_flag, dis_flag)

    # Split latents into behaviour-relevant (br) and domain-specific (ds) halves.
    if dis_flag:
        half = z_src.shape[1] // 2
        z_br_src, z_ds_src = z_src[:, :half], z_src[:, half:]
        z_br_tgt, z_ds_tgt = z_tgt[:, :half], z_tgt[:, half:]
    else:
        z_br_src, z_ds_src = z_src, z_src
        z_br_tgt, z_ds_tgt = z_tgt, z_tgt

    # Domain labels (1 = source, 0 = target), each sized to its own batch so
    # that source and target batches may differ in size.
    label_src_gt = torch.ones(z_src.shape[0], 1, device=device)
    label_tgt_gt = torch.zeros(z_tgt.shape[0], 1, device=device)

    # ---- Phase 1: update the discriminators on detached latents ----
    model.domain_classifier.requires_grad_(True)
    model.behavior_classifier.requires_grad_(True)

    # Forward calls stay interleaved (src, then tgt) so the dropout RNG order is unchanged.
    label_prob_src = model.domain_classifier(z_ds_src.detach())
    dis_loss_src = model.domain_classifier.bce_criterion(label_prob_src, label_src_gt)

    label_prob_br_src = model.behavior_classifier(z_br_src.detach())
    dis_loss_br_src = model.behavior_classifier.bce_criterion(label_prob_br_src, label_src_gt)

    label_prob_tgt = model.domain_classifier(z_ds_tgt.detach())
    dis_loss_tgt = model.domain_classifier.bce_criterion(label_prob_tgt, label_tgt_gt)

    label_prob_br_tgt = model.behavior_classifier(z_br_tgt.detach())
    dis_loss_br_tgt = model.behavior_classifier.bce_criterion(label_prob_br_tgt, label_tgt_gt)

    # These graphs start from detached latents and are never reused, so they
    # do not need to be retained after backward.
    dis_loss_total = dis_loss_src + dis_loss_tgt
    optimizer_ds.zero_grad()
    dis_loss_total.backward()
    optimizer_ds.step()

    print("domain Discriminator Loss = %f" % (dis_loss_total))

    dis_loss_br_total = dis_loss_br_src + dis_loss_br_tgt
    optimizer_br.zero_grad()
    dis_loss_br_total.backward()
    optimizer_br.step()

    print("behavior Discriminator Loss = %f" % (dis_loss_br_total))

    # ---- Phase 2: update the encoder / VAE ----
    optimizer_VAE.zero_grad()

    # Source domain: position readout from the behaviour half, plus reconstruction + KL.
    y_pred_br_src = model.pos_read_out(z_br_src)
    mse_loss_br_src = model.mse_regression(y_pred_br_src, src_y)

    loss_br_src = 0
    if dis_flag:
        recon_br_src = model.domain_Reconstructor_src(z_src)
        rec_loss_br_src = model.mse_regression(recon_br_src, x_hid_src)
        loss_br_src = model.rec_weight*rec_loss_br_src + model.kld_weight_pos*kld_loss_pos_src
    if model.mse_weight > 0:
        loss_br_src += model.mse_weight*mse_loss_br_src

    # Target domain: reconstruction + KL only (no position targets).
    # Tuning notes: rec_weight 1e-2 for Chewie / Mihili / Jango and RT-Mihili,
    # 1e-4 for Spike; kld_weight_pos 1e-4.
    loss_br_tgt = 0
    if dis_flag:
        recon_br_tgt = model.domain_Reconstructor_src(z_tgt)
        rec_loss_br_tgt = model.mse_regression(recon_br_tgt, x_hid_tgt)
        # NOTE(review): the target KL is weighted by kld_weight_rec while the
        # source KL uses kld_weight_pos — confirm this asymmetry is intended.
        loss_br_tgt = model.rec_weight*rec_loss_br_tgt + model.kld_weight_rec*kld_loss_pos_tgt

    # Adversarial alignment of q(z_br | x). Freeze both discriminators so the
    # encoder update cannot train them. behavior_classifier is the one used
    # in the adversarial term below.
    model.domain_classifier.requires_grad_(False)
    model.behavior_classifier.requires_grad_(False)

    label_prob_br_src_after = model.behavior_classifier(z_br_src)
    label_prob_br_tgt_after = model.behavior_classifier(z_br_tgt)
    bce = model.behavior_classifier.bce_criterion
    loss_br_domain = bce(label_prob_br_tgt_after, label_tgt_gt) - bce(label_prob_br_src_after, label_src_gt)

    # The domain-specific adversarial term and the HSIC penalty are disabled.
    total_loss_br = loss_br_src + loss_br_tgt
    if grl_weight > 0:
        total_loss_br -= grl_weight*loss_br_domain
    # With dis_flag False, mse_weight <= 0 and grl_weight <= 0 every term is
    # off and total_loss_br is the plain int 0: there is nothing to back-propagate.
    if torch.is_tensor(total_loss_br):
        total_loss_br.backward()

    src_pred_pos, src_gt_pos = y_pred_br_src.detach().cpu().numpy(), src_y.detach().cpu().numpy()
    r2_score_src = r2_score(src_gt_pos, src_pred_pos)
    print("total Loss = %f, r2 score = %f" % (total_loss_br, r2_score_src))

    optimizer_VAE.step()